import warnings
import argparse

import os
from pathlib import Path

import yaml

import dreamerv3
from baselines.qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids
from dreamerv3.agent import ImagActorCritic
from dreamerv3 import embodied
from dreamerv3.embodied.core.goal_sampler import GoalSampler, GoalSamplerCyclic
from dreamerv3.embodied.core.space import AngularSpace, Space

from utils import Config, get_env, get_argv_from_config

# Silence the repeated dtype-downcast warning (likely emitted by JAX when
# 64-bit types are disabled — TODO confirm) so evaluation output stays readable.
warnings.filterwarnings(action="ignore", message=".*truncated to dtype int32.*")


def get_args():
  """Parse this script's command-line options.

  Returns:
    An ``argparse.Namespace`` with ``path`` (str) and ``num_evals`` (int),
    both of which are required on the command line.
  """
  cli = argparse.ArgumentParser()
  option_specs = (
    (('-p', '--path'),
     dict(type=str, required=True, help='Path to the directory containing the checkpoints.')),
    (('-n', '--num-evals'),
     dict(type=int, required=True, help='Number of evaluations to run per goal.')),
  )
  for flags, options in option_specs:
    cli.add_argument(*flags, **options)
  return cli.parse_args()


def _default_evaluations_dir(path_results):
  """Return ``<grandparent>/<parent>_evaluations/<run_name>`` for a run directory.

  For example ``/exp/results/run_0`` maps to ``/exp/results_evaluations/run_0``.
  """
  path_results = Path(path_results)
  parent_directory = path_results.parent
  return parent_directory.parent / f"{parent_directory.name}_evaluations" / path_results.name


def eval_data_ours(path_results, number_evaluations_per_goal, path_saving_evaluations=None, num_envs=2048):
  """Evaluate a trained run on a regular grid of goals.

  Reads the wandb run config stored under ``path_results``, rebuilds the
  DreamerV3 config from it (passing the run's ``checkpoint.ckpt`` via
  ``--run.from_checkpoint``), and calls ``embodied.run.eval_only`` with a
  cyclic goal sampler over the centroids of a ``resolution``-per-axis grid.

  Args:
    path_results: Run directory containing ``checkpoint.ckpt`` and
      ``wandb/latest-run/files/config.yaml``. May be a str or a Path.
    number_evaluations_per_goal: Forwarded to ``GoalSamplerCyclic`` as
      ``number_visits_per_goal``.
    path_saving_evaluations: Log directory for the evaluation. Defaults to
      ``<grandparent>/<parent>_evaluations/<run_name>``. Created (with any
      missing parents) if it does not exist.
    num_envs: Value passed as ``--envs.amount`` (defaults to 2048).
  """
  # Normalise unconditionally: the path is joined with ``/`` below, which
  # fails on a plain str when a saving directory is supplied by the caller.
  path_results = Path(path_results)

  if path_saving_evaluations is None:
    path_saving_evaluations = _default_evaluations_dir(path_results)
  path_saving_evaluations = Path(path_saving_evaluations)
  # parents=True: the sibling "<parent>_evaluations" directory typically does
  # not exist yet the first time a run from that parent is evaluated.
  path_saving_evaluations.mkdir(parents=True, exist_ok=True)

  run_config_path = path_results / "wandb" / "latest-run" / "files" / "config.yaml"
  checkpoint_path = str(path_results / "checkpoint.ckpt")
  with open(run_config_path, 'r', encoding='utf-8') as f:
    # wandb stores each entry as {"value": ...}, hence the ['value'] lookups.
    run_config = yaml.safe_load(f)

  print(run_config['goal'])

  argv = [
    f"--task={run_config['task']['value']}",
    f"--feat={run_config['feat']['value']}",

    f"--run.from_checkpoint={checkpoint_path}",  # TODO: is this necessary?
    f"--goal.resolution={run_config['goal']['value']['resolution']}",

    f"--envs.amount={num_envs}",
    f"--backend={run_config['backend']['value']}",
  ]

  # Build the DreamerV3 config: defaults, then brax preset, then overrides.
  config_defaults = embodied.Config(dreamerv3.configs["defaults"])
  config_defaults = config_defaults.update(dreamerv3.configs["brax"])
  config_defaults = config_defaults.update({
    "logdir": str(path_saving_evaluations),
    "run.train_ratio": 32,
    "run.log_every": 60,  # Seconds
    "batch_size": 16,
  })
  config = embodied.Flags(config_defaults).parse(argv=argv)

  # Create logger
  step = embodied.Counter()
  logger = embodied.Logger(step, [
    embodied.logger.TerminalOutput(),
  ])

  # Create environment and agent
  env = get_env(config, mode="train")
  agent = dreamerv3.Agent(env.obs_space, env.act_space, env.feat_space, step, config)

  args = embodied.Config(
    **config.run,
    logdir=config.logdir,
    batch_steps=config.batch_size * config.batch_length)

  # Build the goal grid
  resolution = ImagActorCritic.get_resolution(env.feat_space, config)
  print("env.feat_space", env.feat_space)
  feat_vector_space = env.feat_space["vector"]

  if config.feat == "angle":
    print("Angular space")
    # Presumably each angle occupies two vector components (e.g. cos/sin), so
    # the grid has half as many axes as the raw vector — TODO confirm.
    num_grid_axes = feat_vector_space.shape[0] // 2
  else:
    print("Not angular space", feat_vector_space.__class__)
    num_grid_axes = feat_vector_space.shape[0]
  grid_shape = (resolution,) * num_grid_axes
  goals = compute_euclidean_centroids(grid_shape, minval=feat_vector_space.low, maxval=feat_vector_space.high)
  goals = feat_vector_space.transform(goals)
  print("ALL GOALS", goals)

  goal_sampler_cyclic = GoalSamplerCyclic(feat_space=env.feat_space,
                                          goal_list=list(goals),
                                          number_visits_per_goal=number_evaluations_per_goal)
  # period_sample_goals=inf: goals are never resampled on a timer.
  embodied.run.eval_only(agent,
                         env,
                         goal_sampler=goal_sampler_cyclic,
                         period_sample_goals=float('inf'),
                         logger=logger,
                         args=args, )


def main():
  """Command-line entry point: evaluate the run directory given by ``--path``."""
  args = get_args()
  path_results = os.path.abspath(args.path)
  # argparse already converts --num-evals via type=int.
  number_evaluations_per_goal = args.num_evals
  if number_evaluations_per_goal < 1:
    raise SystemExit("--num-evals must be a positive integer.")

  # Call the evaluation routine defined above (NOT the builtin ``eval``,
  # which would raise TypeError for these arguments).
  eval_data_ours(path_results, number_evaluations_per_goal)


if __name__ == "__main__":
  main()
